{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Sequence classification model for IMDB Sentiment Analysis\n",
    "(c) Deniz Yuret, 2019\n",
    "* Objectives: Learn the structure of the IMDB dataset and train a simple RNN model.\n",
    "* Prerequisites: [RNN models](60.rnn.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set display width, load packages, import symbols\n",
    "ENV[\"COLUMNS\"] = 72\n",
    "using Statistics: mean\n",
    "using IterTools: ncycle\n",
    "using Knet: Knet, AutoGrad, RNN, param, dropout, minibatch, nll, accuracy, progress!, adam, save, load, gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0e-8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set constants for the model and training\n",
    "EPOCHS=3          # Number of training epochs\n",
    "BATCHSIZE=64      # Number of instances in a minibatch\n",
    "EMBEDSIZE=125     # Word embedding size\n",
    "NUMHIDDEN=100     # Hidden layer size\n",
    "MAXLEN=150        # Maximum length of a word sequence; shorter sequences are padded, longer ones truncated\n",
    "VOCABSIZE=30000   # Maximum number of token ids, including the <unk>, <s> and <pad> tokens; rarer words map to <unk>\n",
    "NUMCLASS=2        # Number of output classes\n",
    "DROPOUT=0.5       # Dropout rate\n",
    "LR=0.001          # Learning rate\n",
    "BETA_1=0.9        # Adam optimization parameter\n",
    "BETA_2=0.999      # Adam optimization parameter\n",
    "EPS=1e-08         # Adam optimization parameter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and view data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "imdb"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(Knet.dir(\"data\",\"imdb.jl\"))   # defines imdb loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "\\begin{verbatim}\n",
       "imdb()\n",
       "\\end{verbatim}\n",
       "Load the IMDB Movie reviews sentiment classification dataset from https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict) tuple.\n",
       "\n",
       "\\section{Keyword Arguments:}\n",
       "\\begin{itemize}\n",
       "\\item url=https://s3.amazonaws.com/text-datasets: where to download the data (imdb.npz) from.\n",
       "\n",
       "\n",
       "\\item dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "\n",
       "\n",
       "\\item maxval=nothing: max number of token values to include. Words are ranked by how often they occur (in the training set) and only the most frequent words are kept. nothing means keep all, equivalent to maxval = vocabSize + pad + stoken.\n",
       "\n",
       "\n",
       "\\item maxlen=nothing: truncate sequences after this length. nothing means do not truncate.\n",
       "\n",
       "\n",
       "\\item seed=0: random seed for sample shuffling. Use system seed if 0.\n",
       "\n",
       "\n",
       "\\item pad=true: whether to pad short sequences (padding is done at the beginning of sequences). pad\\_token = maxval.\n",
       "\n",
       "\n",
       "\\item stoken=true: whether to add a start token to the beginning of each sequence. start\\_token = maxval - pad.\n",
       "\n",
       "\n",
       "\\item oov=true: whether to replace words >= oov\\_token with oov\\_token (the alternative is to skip them). oov\\_token = maxval - pad - stoken.\n",
       "\n",
       "\\end{itemize}\n"
      ],
      "text/markdown": [
       "```\n",
       "imdb()\n",
       "```\n",
       "\n",
       "Load the IMDB Movie reviews sentiment classification dataset from https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict) tuple.\n",
       "\n",
       "# Keyword Arguments:\n",
       "\n",
       "  * url=https://s3.amazonaws.com/text-datasets: where to download the data (imdb.npz) from.\n",
       "  * dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "  * maxval=nothing: max number of token values to include. Words are ranked by how often they occur (in the training set) and only the most frequent words are kept. nothing means keep all, equivalent to maxval = vocabSize + pad + stoken.\n",
       "  * maxlen=nothing: truncate sequences after this length. nothing means do not truncate.\n",
       "  * seed=0: random seed for sample shuffling. Use system seed if 0.\n",
       "  * pad=true: whether to pad short sequences (padding is done at the beginning of sequences). pad_token = maxval.\n",
       "  * stoken=true: whether to add a start token to the beginning of each sequence. start_token = maxval - pad.\n",
       "  * oov=true: whether to replace words >= oov_token with oov_token (the alternative is to skip them). oov_token = maxval - pad - stoken.\n"
      ],
      "text/plain": [
       "\u001b[36m  imdb()\u001b[39m\n",
       "\n",
       "  Load the IMDB Movie reviews sentiment classification dataset from\n",
       "  https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict)\n",
       "  tuple.\n",
       "\n",
       "\u001b[1m  Keyword Arguments:\u001b[22m\n",
       "\u001b[1m  ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡\u001b[22m\n",
       "\n",
       "    •    url=https://s3.amazonaws.com/text-datasets: where to\n",
       "        download the data (imdb.npz) from.\n",
       "\n",
       "    •    dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "\n",
       "    •    maxval=nothing: max number of token values to include.\n",
       "        Words are ranked by how often they occur (in the training\n",
       "        set) and only the most frequent words are kept. nothing\n",
       "        means keep all, equivalent to maxval = vocabSize + pad +\n",
       "        stoken.\n",
       "\n",
       "    •    maxlen=nothing: truncate sequences after this length.\n",
       "        nothing means do not truncate.\n",
       "\n",
       "    •    seed=0: random seed for sample shuffling. Use system seed\n",
       "        if 0.\n",
       "\n",
       "    •    pad=true: whether to pad short sequences (padding is done\n",
       "        at the beginning of sequences). pad_token = maxval.\n",
       "\n",
       "    •    stoken=true: whether to add a start token to the beginning\n",
       "        of each sequence. start_token = maxval - pad.\n",
       "\n",
       "    •    oov=true: whether to replace words >= oov_token with\n",
       "        oov_token (the alternative is to skip them). oov_token =\n",
       "        maxval - pad - stoken."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@doc imdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading IMDB...\n",
      "└ @ Main /home/deniz/.julia/dev/Knet/data/imdb.jl:57\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  6.632559 seconds (26.86 M allocations: 1.400 GiB, 8.81% gc time)\n"
     ]
    }
   ],
   "source": [
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=VOCABSIZE);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "Dict{String,Int32} with 88584 entries\n"
     ]
    }
   ],
   "source": [
    "println.(summary.((xtrn,ytrn,xtst,ytst,imdbdict)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×150 LinearAlgebra.Adjoint{Int32,Array{Int32,1}}:\n",
       " 30000  30000  30000  30000  29999  …  437  33  67  4389  65  19  228"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Words are encoded with integers\n",
    "rand(xtrn)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×25000 LinearAlgebra.Adjoint{Int64,Array{Int64,1}}:\n",
       " 150  150  150  150  150  150  150  …  150  150  150  150  150  150"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each word sequence is padded or truncated to length 150\n",
    "length.(xtrn)'"
   ]
  },
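  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding convention described in the `imdb` docstring can be illustrated with a minimal sketch. `padtrunc` below is a hypothetical helper, not part of Knet or `imdb.jl`. It assumes the pad token equals `maxval` (30000 here) and that truncation keeps the last `maxlen` tokens; the second point is an assumption about the loader, not something the docstring states.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper for illustration only (not part of Knet)\n",
    "function padtrunc(x::Vector{Int}, maxlen::Int; pad::Int=30000)\n",
    "    n = length(x)\n",
    "    n >= maxlen && return x[end-maxlen+1:end]   # assumption: keep the last maxlen tokens\n",
    "    return vcat(fill(pad, maxlen-n), x)         # padding goes at the beginning\n",
    "end\n",
    "\n",
    "padtrunc([29999, 12, 7], 5)    # [30000, 30000, 29999, 12, 7]\n",
    "padtrunc(collect(1:8), 5)      # [4, 5, 6, 7, 8]\n",
    "```\n",
    "\n",
    "In the encoded sample shown earlier, the leading run of 30000s is padding and the 29999 that follows is the start token (`maxval - pad`)."
   ]
  },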
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "reviewstring (generic function with 2 methods)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define a function that can print the actual words:\n",
    "imdbvocab = Array{String}(undef,length(imdbdict))\n",
    "for (k,v) in imdbdict; imdbvocab[v]=k; end\n",
    "imdbvocab[VOCABSIZE-2:VOCABSIZE] = [\"<unk>\",\"<s>\",\"<pad>\"]\n",
    "function reviewstring(x,y=0)\n",
    "    x = x[x.!=VOCABSIZE] # remove pads\n",
    "    \"\"\"$((\"Sample\",\"Negative\",\"Positive\")[y+1]) review:\\n$(join(imdbvocab[x],\" \"))\"\"\"\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative review:\n",
      "money to see i won't comment on the story itself it's a wonderful classic but here it feels like a soap opera to start with the acting except for eric bana is soap opera quality i've always been a fan of brad pitt but here every actor on the bold and the beautiful puts him to shame the camera action doesn't help either how it lingers on him when he's thinking it just takes me back to brooke <unk> days in the lab peter o'toole has either had a really bad plastic surgery or he is desperately in need of one either way he looks more like linda evans than linda evans and to end my comments diane kruger is a cute girl but she sure is no helen of troy peterson should rather have chosen saffron burrows for the role since elizabeth taylor would be rather miscast by now\n"
     ]
    }
   ],
   "source": [
    "# Hit Ctrl-Enter to see random reviews:\n",
    "r = rand(1:length(xtrn))\n",
    "println(reviewstring(xtrn[r],ytrn[r]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×25000 LinearAlgebra.Adjoint{Int8,Array{Int8,1}}:\n",
       " 2  1  1  2  2  2  1  1  2  1  2  …  2  2  1  1  2  2  2  2  1  1  1"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Here are the labels: 1=negative, 2=positive\n",
    "ytrn'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct SequenceClassifier; input; rnn; output; pdrop; end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifier"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SequenceClassifier(input::Int, embed::Int, hidden::Int, output::Int; pdrop=0) =\n",
    "    SequenceClassifier(param(embed,input), RNN(embed,hidden,rnnType=:gru), param(output,hidden), pdrop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
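    "function (sc::SequenceClassifier)(input)\n",
    "    embed = sc.input[:, permutedims(hcat(input...))]  # embedding lookup: E×B×T\n",
    "    embed = dropout(embed,sc.pdrop)\n",
    "    hidden = sc.rnn(embed)                             # hidden states: H×B×T\n",
    "    hidden = dropout(hidden,sc.pdrop)\n",
    "    return sc.output * hidden[:,:,end]                 # class scores from the last time step: C×B\n",
    "end\n",
    "\n",
    "# When called with labels, return the negative log likelihood loss\n",
    "(sc::SequenceClassifier)(input,output) = nll(sc(input),output)"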
   ]
  },
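  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The array shapes in the embedding lookup can be checked on the CPU with plain Julia arrays. This is only a sketch: the sizes `E`, `V`, `B`, `T` and the names `W`, `batch`, `idx`, `emb` are made up for illustration, with `W` standing in for `sc.input`.\n",
    "\n",
    "```julia\n",
    "E, V, B, T = 4, 10, 3, 5                 # toy embedding size, vocabulary, batch size, sequence length\n",
    "W = randn(Float32, E, V)                 # stands in for the embedding matrix sc.input\n",
    "batch = [rand(1:V, T) for _ in 1:B]      # B sequences of T token ids, like a minibatch of reviews\n",
    "idx = permutedims(hcat(batch...))        # hcat gives T×B, permutedims turns it into B×T\n",
    "emb = W[:, idx]                          # indexing with a matrix gives an E×B×T array\n",
    "size(idx), size(emb)                     # ((3, 5), (4, 3, 5))\n",
    "emb[:, 2, 4] == W[:, batch[2][4]]        # true: the embedding of token 4 in sequence 2\n",
    "```\n",
    "\n",
    "The RNN layer maps this E×B×T input to an H×B×T array of hidden states, so `hidden[:,:,end]` selects the H×B matrix for the last time step."
   ]
  },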
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(390, 390)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtrn = minibatch(xtrn,ytrn,BATCHSIZE;shuffle=true)\n",
    "dtst = minibatch(xtst,ytst,BATCHSIZE)\n",
    "length.((dtrn,dtst))"
   ]
  },
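  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch count of 390 can be verified by hand. The arithmetic below assumes that `minibatch` drops the final partial batch, which is consistent with the reported length; the variable names are local to this sketch.\n",
    "\n",
    "```julia\n",
    "n, batchsize, epochs = 25000, 64, 3\n",
    "nbatches = n ÷ batchsize          # 390 full batches per pass over the data\n",
    "leftover = n % batchsize          # 40 instances left out of each pass\n",
    "iterations = nbatches * epochs    # 1170 updates over 3 epochs\n",
    "```\n",
    "\n",
    "The 1170 updates correspond to the iteration count shown by the progress bar during training."
   ]
  },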
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainresults (generic function with 1 method)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For running experiments\n",
    "function trainresults(file,maker; o...)\n",
    "    if (print(\"Train from scratch? \"); startswith(readline(), \"y\"))  # empty input counts as no\n",
    "        model = maker()\n",
    "        progress!(adam(model,ncycle(dtrn,EPOCHS);lr=LR,beta1=BETA_1,beta2=BETA_2,eps=EPS))\n",
    "        Knet.save(file,\"model\",model)\n",
    "        GC.gc(true) # Run a full garbage collection to free GPU memory\n",
    "    else\n",
    "        isfile(file) || download(\"http://people.csail.mit.edu/deniz/models/tutorial/$file\",file)\n",
    "        model = Knet.load(file,\"model\")\n",
    "    end\n",
    "    return model\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "maker (generic function with 1 method)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "maker() = SequenceClassifier(VOCABSIZE,EMBEDSIZE,NUMHIDDEN,NUMCLASS,pdrop=DROPOUT)\n",
    "# model = maker()\n",
    "# nll(model,dtrn), nll(model,dtst), accuracy(model,dtrn), accuracy(model,dtst)\n",
    "# (0.69312066f0, 0.69312423f0, 0.5135817307692307, 0.5096153846153846)"
   ]
  },
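  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss of about 0.693 quoted for the untrained model is what one would expect at chance level with two classes: when both classes receive the same score, the negative log likelihood is `log(2)`. The helper `nll1` below is a hypothetical single-instance version written in plain Julia for illustration; it is not Knet's `nll`.\n",
    "\n",
    "```julia\n",
    "# Negative log likelihood of class y under a softmax over the scores (illustration only)\n",
    "nll1(scores, y) = -(scores[y] - log(sum(exp.(scores))))\n",
    "\n",
    "nll1([0.0, 0.0], 1)    # ≈ 0.6931, i.e. log(2)\n",
    "```\n",
    "\n",
    "A trained model should move well below this value on the training set, and the gap between training and test loss gives a rough indication of overfitting."
   ]
  },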
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train from scratch? stdin> y\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┣████████████████████┫ [100.00%, 1170/1170, 00:21/00:21, 54.84i/s] \n"
     ]
    }
   ],
   "source": [
    "model = trainresults(\"imdbmodel132.jld2\",maker);\n",
    "# ┣████████████████████┫ [100.00%, 1170/1170, 00:15/00:15, 76.09i/s]\n",
    "# nll(model,dtrn), nll(model,dtst), accuracy(model,dtrn), accuracy(model,dtst)\n",
    "# (0.05217469f0, 0.3827392f0, 0.9865785256410257, 0.8576121794871795)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Playground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "str2ids (generic function with 1 method)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictstring(x)=\"\\nPrediction: \" * (\"Negative\",\"Positive\")[argmax(Array(vec(model([x]))))]\n",
    "UNK = VOCABSIZE-2\n",
    "str2ids(s::String)=[(i=get(imdbdict,w,UNK); i>=UNK ? UNK : i) for w in split(lowercase(s))]"
   ]
  },
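  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The behaviour of `str2ids` can be seen on a toy dictionary. Everything in this sketch (`toydict`, `unkid`, `toids`) is hypothetical and mirrors the definition above without touching `imdbdict`.\n",
    "\n",
    "```julia\n",
    "toydict = Dict(\"this\"=>11, \"movie\"=>17, \"rare\"=>29999)   # hypothetical toy dictionary\n",
    "unkid = 30000 - 2                                          # 29998, the <unk> id when VOCABSIZE=30000\n",
    "toids(s) = [(i = get(toydict, w, unkid); i >= unkid ? unkid : i) for w in split(lowercase(s))]\n",
    "\n",
    "toids(\"This movie zzz rare\")    # [11, 17, 29998, 29998]\n",
    "```\n",
    "\n",
    "Unknown words and words whose id falls outside the kept vocabulary both map to the `<unk>` id. Because the input is only lowercased and split on whitespace, a word with attached punctuation such as `movie!` is also treated as unknown."
   ]
  },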
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative review:\n",
      "fest and draw my applause br br i love philosophical films this isn't one of them anyone who is amazed at the depths of intellect <unk> in this film hasn't read a good book lately or ever the thought provoking dialogue is trite at best perhaps it lost something in the translation br br i love a good horror comedy this isn't one of them laugh i thought i'd never start squirm only when trying to think of a polite way to phrase my feedback of the film to the friend who recommended it br br rupert is <unk> good in the setting of this film but even he cannot resurrect it i only wish he had shot the director instead if the zombies br br for shame that the land that gave rise to the inferno should also give rise to this dante would be spinning in his grave\n",
      "\n",
      "Prediction: Negative\n"
     ]
    }
   ],
   "source": [
    "# Here we can see predictions for random reviews from the test set; hit Ctrl-Enter to sample:\n",
    "r = rand(1:length(xtst))\n",
    "println(reviewstring(xtst[r],ytst[r]))\n",
    "println(predictstring(xtst[r]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stdin> i cannot recommend this movie\n",
      "\n",
      "Prediction: Negative\n"
     ]
    }
   ],
   "source": [
    "# Here the user can enter their own reviews and classify them:\n",
    "println(predictstring(str2ids(readline(stdin))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "julia.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Julia 1.5.0",
   "language": "julia",
   "name": "julia-1.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
